{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "5d39ed87",
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import anthropic\n",
    "import os\n",
    "import json\n",
    "import re\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Read the key from the environment so it is never stored in the notebook itself;\n",
    "# falls back to an empty string when ANTHROPIC_API_KEY is not set.\n",
    "api_key = os.environ.get('ANTHROPIC_API_KEY', '')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8ec1bbe2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test split whose `question` field supplies the prompts built below.\n",
    "dataset = load_dataset(\n",
    "    'huseinzol05/malaysian-dialect-qa',\n",
    "    split='test',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "42bfeca0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Companion test split; the scoring cell reads its `from_lang`, `to_lang` and `answer` fields.\n",
    "dataset_lang = load_dataset(\n",
    "    'huseinzol05/malaysian-dialect-qa-lang',\n",
    "    split='test',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "71070129",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "140"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Every prompt gets the same trailing instruction asking for a boxed final answer.\n",
    "answer_instruction = '\\n\\nAfter that, put your final answer within $\\\\boxed{}$.'\n",
    "\n",
    "# Keep the dataset index next to each prompt; it is later used to name the cached output files.\n",
    "questions = [\n",
    "    (idx, dataset[idx]['question'] + answer_instruction)\n",
    "    for idx in range(len(dataset))\n",
    "]\n",
    "\n",
    "len(questions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ac95e1b2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0,\n",
       " 'Lepas ujan jangan maen lari-lari, kan biyak.\\n\\nterjemah ke melayu baku\\n\\nAfter that, put your final answer within $\\\\boxed{}$.')"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the first (index, prompt) pair to check that the boxed-answer instruction was appended.\n",
    "questions[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ab537b0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "folder = 'antrophic-sonnet4-reasoning'\n",
    "\n",
    "# Generations are cached one file per sample inside `folder` and skipped when the\n",
    "# file already exists, so the directory is only wiped on explicit request.\n",
    "# Set this to True to discard every cached answer and regenerate from scratch.\n",
    "reset_folder = False\n",
    "if reset_folder:\n",
    "    !rm -rf {folder}\n",
    "\n",
    "# Unlike `!mkdir`, this does not complain when the directory is already there.\n",
    "os.makedirs(folder, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "e97d1329",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _extract_boxed_answer(text):\n",
    "    \"\"\"\n",
    "    Pull the final answer out of a model reply.\n",
    "\n",
    "    Returns what follows the last `boxed{` up to the first `}` after it,\n",
    "    keeping only the part after `text{` when that marker is present.\n",
    "    Replies that contain no `boxed{` are returned unchanged.\n",
    "    \"\"\"\n",
    "    if 'boxed{' not in text:\n",
    "        return text\n",
    "    answer = text.split('boxed{')[-1].split('}')[0]\n",
    "    if 'text{' in answer:\n",
    "        answer = answer.split('text{')[1]\n",
    "    return answer\n",
    "\n",
    "\n",
    "def generate_answer(row, repeat = 5, thinking = False):\n",
    "    \"\"\"\n",
    "    Sample `repeat` answers for one question and cache each one on disk.\n",
    "\n",
    "    Sample `k` of question `no` is written to `{folder}/{no}-{k}.json`.\n",
    "    Samples whose file already exists are skipped, so the function can be\n",
    "    called again to resume an interrupted run.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    row : tuple\n",
    "        `(no, q)` pair: the question index and the prompt text.\n",
    "    repeat : int\n",
    "        Number of samples to generate for the question.\n",
    "    thinking : bool\n",
    "        When True, extended thinking is enabled with a 6000-token budget.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "        Results are only persisted as JSON files. Failures are printed.\n",
    "    \"\"\"\n",
    "    no, q = row\n",
    "    if thinking:\n",
    "        thinking_mode = {'type': 'enabled', 'budget_tokens': 6000}\n",
    "    else:\n",
    "        thinking_mode = {'type': 'disabled'}\n",
    "\n",
    "    # One client serves every sample of this question.\n",
    "    client = anthropic.Anthropic(api_key=api_key)\n",
    "    max_attempts = 5\n",
    "\n",
    "    for k in range(repeat):\n",
    "        filename = os.path.join(folder, f'{no}-{k}.json')\n",
    "        if os.path.exists(filename):\n",
    "            continue\n",
    "\n",
    "        for _ in range(max_attempts):\n",
    "            try:\n",
    "                # The request sits inside the try so that a failed API call\n",
    "                # uses up one attempt instead of aborting the whole run.\n",
    "                message = client.messages.create(\n",
    "                    model='claude-sonnet-4-20250514',\n",
    "                    max_tokens=8192,\n",
    "                    temperature=1,\n",
    "                    messages=[\n",
    "                        {\n",
    "                            'role': 'user',\n",
    "                            'content': [{'type': 'text', 'text': q}],\n",
    "                        }\n",
    "                    ],\n",
    "                    thinking=thinking_mode,\n",
    "                )\n",
    "\n",
    "                # Pick text blocks by type rather than by position: with\n",
    "                # thinking enabled the reply also carries non-text blocks and\n",
    "                # the answer is not guaranteed to sit at a fixed index.\n",
    "                text_parts = [\n",
    "                    block.text for block in message.content if block.type == 'text'\n",
    "                ]\n",
    "                if not text_parts:\n",
    "                    raise ValueError(f'no text block in the reply for {filename}')\n",
    "\n",
    "                answer = _extract_boxed_answer(''.join(text_parts))\n",
    "                with open(filename, 'w') as fopen:\n",
    "                    json.dump(answer, fopen)\n",
    "                break\n",
    "            except Exception as e:\n",
    "                print(e)\n",
    "        else:\n",
    "            # Every attempt failed; say so instead of silently leaving a gap.\n",
    "            print(f'giving up on {filename} after {max_attempts} failed attempts')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "04c47f27",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 140/140 [2:03:49<00:00, 53.07s/it]\n"
     ]
    }
   ],
   "source": [
    "# Sample answers for every prompt with extended thinking turned on.\n",
    "for row in tqdm(questions):\n",
    "    generate_answer(row, thinking = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "9cca8a9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sacrebleu.metrics import CHRF\n",
    "from glob import glob\n",
    "from collections import defaultdict\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "1bd96a59",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 140/140 [00:00<00:00, 1075.11it/s]\n"
     ]
    }
   ],
   "source": [
    "chrf = CHRF()\n",
    "# Maps '<from_lang><><to_lang>' to the list of best chrF scores, one per example.\n",
    "pairs = defaultdict(list)\n",
    "\n",
    "for i in tqdm(range(len(dataset_lang))):\n",
    "    # Index the dataset once per example instead of once per field.\n",
    "    row = dataset_lang[i]\n",
    "    from_lang, to_lang = row['from_lang'], row['to_lang']\n",
    "    gt = row['answer']\n",
    "    pair = f'{from_lang}<>{to_lang}'\n",
    "\n",
    "    # Cached samples for example i; the directory must match `folder`.\n",
    "    files = glob(f'antrophic-sonnet4-reasoning/{i}-*.json')\n",
    "    scores = []\n",
    "    for f in files:\n",
    "        with open(f) as fopen:\n",
    "            d = json.load(fopen)\n",
    "        scores.append(chrf.corpus_score([d], [[gt]]).score)\n",
    "\n",
    "    # `max` raises ValueError on an empty list, which happens when no sample\n",
    "    # was cached for this example; report it and move on instead of crashing.\n",
    "    if not scores:\n",
    "        print(f'no generations found for example {i} ({pair}), skipping')\n",
    "        continue\n",
    "\n",
    "    # Keep the highest chrF among the samples of this example.\n",
    "    pairs[pair].append(max(scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "81fa0c65",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "From: johor To: malay, score: 54.255257504569634\n",
      "From: kedah To: malay, score: 57.449596027901556\n",
      "From: pahang To: malay, score: 56.050299072477785\n",
      "From: negeri sembilan To: malay, score: 54.09057071374\n",
      "From: kelantan To: malay, score: 45.3943144521222\n",
      "From: penang To: malay, score: 64.80883099915431\n",
      "From: melaka To: malay, score: 45.44978279611746\n",
      "From: malay To: johor, score: 47.94515660238058\n",
      "From: malay To: kedah, score: 45.27348894658902\n",
      "From: malay To: pahang, score: 47.98090699950979\n",
      "From: malay To: negeri sembilan, score: 48.028256484596945\n",
      "From: malay To: kelantan, score: 32.13859674498915\n",
      "From: malay To: penang, score: 41.04876819683223\n",
      "From: malay To: melaka, score: 52.83520776382694\n"
     ]
    }
   ],
   "source": [
    "# Print the mean of the per-example best scores for each translation direction.\n",
    "for pair_name, pair_scores in pairs.items():\n",
    "    source_lang, target_lang = pair_name.split('<>')\n",
    "    print(f'From: {source_lang} To: {target_lang}, score:', np.mean(pair_scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "32c13153",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "53.9283787951547"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Average of the per-pair means over every direction that translates INTO malay.\n",
    "# Computed straight from `pairs` rather than from a pasted copy of the printed\n",
    "# report, so the figure cannot go stale when the scores are recomputed.\n",
    "scores = [\n",
    "    np.mean(pair_scores)\n",
    "    for pair_name, pair_scores in pairs.items()\n",
    "    if pair_name.split('<>')[1] == 'malay'\n",
    "]\n",
    "\n",
    "np.mean(scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "9257e572",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "45.035768819817804"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Average of the per-pair means over every direction that translates FROM malay.\n",
    "# Computed straight from `pairs` rather than from a pasted copy of the printed\n",
    "# report, so the figure cannot go stale when the scores are recomputed.\n",
    "scores = [\n",
    "    np.mean(pair_scores)\n",
    "    for pair_name, pair_scores in pairs.items()\n",
    "    if pair_name.split('<>')[0] == 'malay'\n",
    "]\n",
    "\n",
    "np.mean(scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e6aaf0a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3.10",
   "language": "python",
   "name": "python3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
